
import pickle
import torch
import torch.nn as nn
import torchvision.transforms as transforms
import csv
from torch.utils.data import DataLoader, Dataset



# Interactive configuration. Surrounding whitespace is stripped because every
# answer is spliced directly into a file name further down.
model_class = input("Input the target group. (A or B) :").strip().upper()
model_number = input("Input the original (target) model number. (1-8) : ").strip()
query = input("How many queries do you want to? (1500, 3000, 4500, 6500, 7500) :").strip()
save_file = input("Save result csv file name : ").strip()

# Accept 'a'/'b' as well as 'A'/'B'. Any other value would be used verbatim in
# the model/CSV file names and fall into the group-B branch of the report loop,
# so stop with a clear message instead.
if model_class not in ("A", "B"):
    raise SystemExit(f"Unknown target group {model_class!r}; expected A or B.")

# All inference runs on the CPU; CNN1 moves itself to this device on construction.
device = torch.device("cpu")

def load_object(file_path):
    """Unpickle and return the single object stored at ``file_path``.

    Security: ``pickle`` can execute arbitrary code while loading, so only
    call this on files you produced yourself or otherwise trust.
    """
    with open(file_path, mode='rb') as handle:
        return pickle.load(handle)
# Raw image store, looked up by the image-id strings found in CSV columns 5-7.
# NOTE(review): this file is unpickled; only load a copy you produced yourself.
loaded_img_data = load_object(file_path='../train_img_128.pkl')
class CNN1(nn.Module):
    """Two-conv CNN for single-channel inputs with a 2-way classification head.

    ``forward`` returns both the logits and the pre-ReLU ``fc1`` activations;
    the script uses the latter as a 64-dimensional embedding.
    """

    def __init__(self):
        super().__init__()
        # Attribute names (conv1, conv2, fc1, fc2) become the state_dict keys,
        # so they must stay in sync with the saved checkpoints.
        self.conv1 = nn.Conv2d(1, 16, kernel_size=3, stride=1, padding=1)
        self.conv2 = nn.Conv2d(16, 32, kernel_size=3, stride=1, padding=1)
        # One 2x2 max-pool shared by both conv stages; each use halves H and W.
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2, padding=0)
        # 64x64 input -> 32x32 -> 16x16 after the two poolings, 32 channels.
        self.fc1 = nn.Linear(32 * 16 * 16, 64)
        self.fc2 = nn.Linear(64, 2)
        self.to(device)

    def forward(self, x):
        """Return ``(logits, embedding)`` for ``x``.

        The flatten uses ``view(-1, 32 * 16 * 16)``, so an unbatched
        single-image input yields one row, and a batch yields one row per image.
        """
        for conv in (self.conv1, self.conv2):
            x = self.pool(torch.relu(conv(x)))
        embedding = self.fc1(x.view(-1, 32 * 16 * 16))
        logits = self.fc2(torch.relu(embedding))
        return logits, embedding


class CustomDataset(Dataset):
    """Index-aligned ``(image, target)`` pairs with an optional image transform.

    Args:
        images: indexable collection; ``images[i]`` pairs with ``targets[i]``.
        targets: indexable collection of labels, indexed like ``images``.
        transform: optional callable applied to each image when it is fetched.
    """

    def __init__(self, images, targets, transform=None):
        self.images, self.targets, self.transform = images, targets, transform

    def __len__(self):
        # Length comes from ``images`` only; ``targets`` is assumed to match.
        return len(self.images)

    def __getitem__(self, idx):
        image, target = self.images[idx], self.targets[idx]
        return (self.transform(image) if self.transform else image), target


# Per-image preprocessing before CNN1: convert the raw image (presumably an
# array or tensor; ToPILImage accepts either) to PIL, resize to 64x64, and
# convert to a float tensor. 64x64 is the size CNN1's fc1 input
# (32 * 16 * 16 after two 2x poolings) is built around.
transform = transforms.Compose(
    [
        transforms.ToPILImage(),
        transforms.Resize((64, 64)),
        transforms.ToTensor(),
    ]
)



# Rebuild the architecture and load the chosen target weights, e.g. "modelA_3.pth".
model = CNN1()
# map_location remaps tensors saved on another device (e.g. a GPU) onto
# `device`, so checkpoints trained on CUDA still load in this CPU-only script.
# NOTE(review): torch.load unpickles the file; only load trusted checkpoints
# (newer PyTorch versions additionally accept weights_only=True).
model.load_state_dict(
    torch.load('model' + model_class + '_' + str(model_number) + '.pth', map_location=device)
)
# Inference only. This is a no-op for the current layers, but it keeps results
# deterministic if dropout or batch-norm is ever added to CNN1.
model.eval()

def _format_distance(value):
    """Render a distance with 4 significant digits for the report.

    Replaces ``str(value)[:5]``, which had three problems:
    - it truncated instead of rounding;
    - it left a trailing "." for values in [1000, 10000);
    - it cut scientific notation into a wrong number (1.2345e-05 -> "1.234").
    """
    return f"{value:.4g}"


def _triplet_distances(line):
    """Return formatted pairwise distances for the three images named in line[5:8].

    Each image id is looked up in ``loaded_img_data``, passed through
    ``transform`` and ``model``, and the second model output (fc1 activations)
    is used. The order matches the report columns:
    d(img1, img2), d(img2, img3), d(img1, img3).
    """
    z1, z2, z3 = (model(transform(loaded_img_data[key]))[1] for key in line[5:8])
    return [
        _format_distance(torch.norm(z1 - z2).item()),
        _format_distance(torch.norm(z2 - z3).item()),
        _format_distance(torch.norm(z1 - z3).item()),
    ]


# Minimum row width each mode indexes into: A reads columns 0-9, B copies 0-16.
min_columns = 10 if model_class == "A" else 17
report = []
csv_path = 'simul_meta_' + model_class + '_' + str(query) + '.csv'
# `with` closes the input file (it used to stay open). no_grad skips building
# an autograd graph per image; the forward values are unchanged.
with open(csv_path, 'r', encoding='utf-8') as f, torch.no_grad():
    for row_number, line in enumerate(csv.reader(f), start=1):
        if len(line) < min_columns:
            raise ValueError(
                f"{csv_path} row {row_number}: expected at least "
                f"{min_columns} columns, got {len(line)}"
            )
        if model_class == "A":
            # Keep columns 0-9. Distances go to columns 11-13 for group-A rows
            # only; every other report column is left blank.
            distances = _triplet_distances(line) if line[2] == 'A' else ["", "", ""]
            report.append(line[:10] + [""] + distances + ["", "", ""])
        else:
            # Copy columns 0-16 through, overwriting 11-13 with fresh
            # distances for group-B rows.
            distances = _triplet_distances(line) if line[2] == 'B' else line[11:14]
            report.append(line[:11] + distances + line[14:17])

# Write every assembled report row. Use UTF-8 explicitly to match the encoding
# the input CSV is read with, instead of the platform default.
with open(save_file + '.csv', 'w', newline='', encoding='utf-8') as csvfile:
    csv.writer(csvfile).writerows(report)
